{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Structure Learning in Bayesian Networks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we show a few examples of Causal Discovery or Structure Learning in pgmpy. pgmpy currently has the following algorithm for causal discovery:\n",
    "\n",
    "1. **PC**: Has $3$ variants original, stable, and parallel. PC is a constraint-based algorithm that utilizes Conditional Independence tests to construct the model.\n",
    "2. **Hill-Climb Search**: Hill-Climb Search is a greedy optimization-based algorithm that makes iterative local changes to the model structure such that it improves the overall score of the model.\n",
    "3. **Greedy Equivalence Search (GES)**: Another score-based method that makes greedy modifications to the model to improve its score iteratively.\n",
    "4. **ExpertInLoop**: An iterative algorithm that combines Conditional Independence testing with expert knowledge. The user or an LLM can act as the expert.\n",
    "5. **Exhaustive Search**: Exhaustive search iterates over all possible network structures on the given variables to find the most optimal one. As it tries to enumerate all possible network structures, it is intractable when the number of variables in the data is large.\n",
    "\n",
    "The following Conditional Independence Tests are available to use with PC algorithm.\n",
    "1. **Discrete Data**: When all variables are discrete/categorical.\n",
    "    1. **Chi-square test**: `ci_test=\"chi_square\"`\n",
    "    2. **G-squared**: `ci_test=\"g_sq\"`\n",
    "    3. **Log-likelihood**: Is equivalent to G-squared test. `ci_test=\"log_likelihood`\n",
    "2. **Continuous Data**: When all variables are continuous/numerical.\n",
    "    1. **Partial Correlation**: `ci_test=\"pearsonr\"`\n",
    "3. **Mixed Data**: When there is a mix of categorical and continuous variables.\n",
    "    1. **Pillai**: `ci_test=\"pillai\"`\n",
    "\n",
    "For Hill-Climb, Exhausitive Search, and GES the following scoring methods can be used:\n",
    "1. **Discrete Data**: When all variables are discrete/categorical. \n",
    "    1. **BIC Score**: `scoring_method=\"bic-d\"`\n",
    "    2. **AIC Score**: `scoring_method=\"aic-d\"`\n",
    "    3. **K2 Score**: `scoring_method=\"k2\"`\n",
    "    4. **BDeU Score**: `scoring_method=\"bdeu\"`\n",
    "    5. **BDs Score**: `scoring_method=\"bds\"`\n",
    "2. **Continuous Data**: When all variables are continuous/numerical.\n",
    "    1. **Log-Likelihood**: `scoring_method=\"ll-g\"`\n",
    "    2. **AIC**: `scoring_method=\"aic-g\"`\n",
    "    3. **BIC**: `scoring_method=\"bic-g\"`\n",
    "3. **Mixed Data**: When there is a mix of discrete and continuous variables.\n",
    "    1. **AIC**: `scoring_method=\"aic-cg\"`\n",
    "    2. **BIC**: `scoring_method=\"bic-cg\"`"
   ]
  },
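  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the intractability of Exhaustive Search concrete, the sketch below counts the labeled DAGs on $n$ nodes using Robinson's recurrence, $a(n) = \\sum_{k=1}^{n} (-1)^{k+1} \\binom{n}{k} 2^{k(n-k)} a(n-k)$ with $a(0) = 1$. It relies only on the Python standard library; the helper name `count_dags` is ours for illustration and is not part of pgmpy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import lru_cache\n",
    "from math import comb\n",
    "\n",
    "\n",
    "@lru_cache(maxsize=None)\n",
    "def count_dags(n):\n",
    "    # Number of labeled DAGs on n nodes, via Robinson's recurrence.\n",
    "    if n == 0:\n",
    "        return 1\n",
    "    return sum(\n",
    "        (-1) ** (k + 1) * comb(n, k) * 2 ** (k * (n - k)) * count_dags(n - k)\n",
    "        for k in range(1, n + 1)\n",
    "    )\n",
    "\n",
    "\n",
    "# The first values are 1, 3, 25, 543, 29281; the count grows super-exponentially.\n",
    "for n in (1, 2, 3, 4, 5, 10):\n",
    "    print(n, count_dags(n))"
   ]
  },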
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Simulate some sample datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import combinations\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "from pgmpy.estimators import PC, HillClimbSearch, GES\n",
    "from pgmpy.utils import get_example_model\n",
    "from pgmpy.metrics import SHD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aa35b1a7897646d7ac141e27e04a6165",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/37 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:pgmpy:Probability values don't exactly sum to 1. Differ by: -2.220446049250313e-16. Adjusting values.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>b1191</th>\n",
       "      <th>cspG</th>\n",
       "      <th>eutG</th>\n",
       "      <th>fixC</th>\n",
       "      <th>cspA</th>\n",
       "      <th>yecO</th>\n",
       "      <th>yedE</th>\n",
       "      <th>sucA</th>\n",
       "      <th>cchB</th>\n",
       "      <th>yceP</th>\n",
       "      <th>...</th>\n",
       "      <th>dnaK</th>\n",
       "      <th>folK</th>\n",
       "      <th>ycgX</th>\n",
       "      <th>lacZ</th>\n",
       "      <th>nuoM</th>\n",
       "      <th>dnaG</th>\n",
       "      <th>b1583</th>\n",
       "      <th>mopB</th>\n",
       "      <th>yaeM</th>\n",
       "      <th>ftsJ</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.060641</td>\n",
       "      <td>2.044477</td>\n",
       "      <td>0.341216</td>\n",
       "      <td>1.448399</td>\n",
       "      <td>-0.351716</td>\n",
       "      <td>2.189750</td>\n",
       "      <td>-1.689554</td>\n",
       "      <td>-0.228456</td>\n",
       "      <td>2.871002</td>\n",
       "      <td>-0.433597</td>\n",
       "      <td>...</td>\n",
       "      <td>1.522817</td>\n",
       "      <td>1.645983</td>\n",
       "      <td>1.595502</td>\n",
       "      <td>2.465247</td>\n",
       "      <td>-0.532987</td>\n",
       "      <td>1.126289</td>\n",
       "      <td>-0.302589</td>\n",
       "      <td>0.773483</td>\n",
       "      <td>3.884857</td>\n",
       "      <td>0.729557</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.632151</td>\n",
       "      <td>0.964321</td>\n",
       "      <td>0.830229</td>\n",
       "      <td>0.696598</td>\n",
       "      <td>0.639204</td>\n",
       "      <td>0.058108</td>\n",
       "      <td>-0.736189</td>\n",
       "      <td>0.712095</td>\n",
       "      <td>1.467498</td>\n",
       "      <td>0.320727</td>\n",
       "      <td>...</td>\n",
       "      <td>1.222602</td>\n",
       "      <td>1.790727</td>\n",
       "      <td>1.763590</td>\n",
       "      <td>2.945772</td>\n",
       "      <td>-2.532464</td>\n",
       "      <td>1.460699</td>\n",
       "      <td>2.732595</td>\n",
       "      <td>0.097982</td>\n",
       "      <td>2.566064</td>\n",
       "      <td>0.652853</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.585766</td>\n",
       "      <td>2.862437</td>\n",
       "      <td>0.922291</td>\n",
       "      <td>0.370000</td>\n",
       "      <td>0.723932</td>\n",
       "      <td>2.487161</td>\n",
       "      <td>-1.916624</td>\n",
       "      <td>-0.300359</td>\n",
       "      <td>2.050980</td>\n",
       "      <td>-0.064301</td>\n",
       "      <td>...</td>\n",
       "      <td>0.599305</td>\n",
       "      <td>1.302091</td>\n",
       "      <td>0.509717</td>\n",
       "      <td>3.090268</td>\n",
       "      <td>-1.745613</td>\n",
       "      <td>0.168043</td>\n",
       "      <td>0.851346</td>\n",
       "      <td>-0.640472</td>\n",
       "      <td>5.800712</td>\n",
       "      <td>0.031888</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.802866</td>\n",
       "      <td>2.277038</td>\n",
       "      <td>0.608559</td>\n",
       "      <td>2.180283</td>\n",
       "      <td>0.116453</td>\n",
       "      <td>2.539035</td>\n",
       "      <td>-1.656839</td>\n",
       "      <td>-1.420540</td>\n",
       "      <td>2.605192</td>\n",
       "      <td>-0.160302</td>\n",
       "      <td>...</td>\n",
       "      <td>2.015663</td>\n",
       "      <td>2.823588</td>\n",
       "      <td>2.101625</td>\n",
       "      <td>3.521299</td>\n",
       "      <td>-1.212391</td>\n",
       "      <td>1.485369</td>\n",
       "      <td>2.449296</td>\n",
       "      <td>2.032116</td>\n",
       "      <td>3.888917</td>\n",
       "      <td>3.036650</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.868548</td>\n",
       "      <td>2.480999</td>\n",
       "      <td>1.079364</td>\n",
       "      <td>2.413862</td>\n",
       "      <td>0.961743</td>\n",
       "      <td>2.280195</td>\n",
       "      <td>-1.740610</td>\n",
       "      <td>-0.472692</td>\n",
       "      <td>3.338063</td>\n",
       "      <td>-0.509144</td>\n",
       "      <td>...</td>\n",
       "      <td>1.142139</td>\n",
       "      <td>1.251832</td>\n",
       "      <td>-0.145789</td>\n",
       "      <td>2.057217</td>\n",
       "      <td>-2.862230</td>\n",
       "      <td>-0.089721</td>\n",
       "      <td>1.020832</td>\n",
       "      <td>0.040589</td>\n",
       "      <td>4.925848</td>\n",
       "      <td>0.575977</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 46 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      b1191      cspG      eutG      fixC      cspA      yecO      yedE  \\\n",
       "0  1.060641  2.044477  0.341216  1.448399 -0.351716  2.189750 -1.689554   \n",
       "1  0.632151  0.964321  0.830229  0.696598  0.639204  0.058108 -0.736189   \n",
       "2  0.585766  2.862437  0.922291  0.370000  0.723932  2.487161 -1.916624   \n",
       "3  1.802866  2.277038  0.608559  2.180283  0.116453  2.539035 -1.656839   \n",
       "4  1.868548  2.480999  1.079364  2.413862  0.961743  2.280195 -1.740610   \n",
       "\n",
       "       sucA      cchB      yceP  ...      dnaK      folK      ycgX      lacZ  \\\n",
       "0 -0.228456  2.871002 -0.433597  ...  1.522817  1.645983  1.595502  2.465247   \n",
       "1  0.712095  1.467498  0.320727  ...  1.222602  1.790727  1.763590  2.945772   \n",
       "2 -0.300359  2.050980 -0.064301  ...  0.599305  1.302091  0.509717  3.090268   \n",
       "3 -1.420540  2.605192 -0.160302  ...  2.015663  2.823588  2.101625  3.521299   \n",
       "4 -0.472692  3.338063 -0.509144  ...  1.142139  1.251832 -0.145789  2.057217   \n",
       "\n",
       "       nuoM      dnaG     b1583      mopB      yaeM      ftsJ  \n",
       "0 -0.532987  1.126289 -0.302589  0.773483  3.884857  0.729557  \n",
       "1 -2.532464  1.460699  2.732595  0.097982  2.566064  0.652853  \n",
       "2 -1.745613  0.168043  0.851346 -0.640472  5.800712  0.031888  \n",
       "3 -1.212391  1.485369  2.449296  2.032116  3.888917  3.036650  \n",
       "4 -2.862230 -0.089721  1.020832  0.040589  4.925848  0.575977  \n",
       "\n",
       "[5 rows x 46 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Discrete variable dataset\n",
    "alarm_model = get_example_model(\"alarm\")\n",
    "alarm_samples = alarm_model.simulate(int(1e3))\n",
    "alarm_samples.head()\n",
    "\n",
    "# Continuous variable dataset\n",
    "ecoli_model = get_example_model(\"ecoli70\")\n",
    "ecoli_samples = ecoli_model.simulate(int(1e3))\n",
    "ecoli_samples.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to evaluate the learned model structures.\n",
    "def get_f1_score(estimated_model, true_model):\n",
    "    nodes = estimated_model.nodes()\n",
    "    est_adj = nx.to_numpy_array(\n",
    "        estimated_model.to_undirected(), nodelist=nodes, weight=None\n",
    "    )\n",
    "    true_adj = nx.to_numpy_array(\n",
    "        true_model.to_undirected(), nodelist=nodes, weight=None\n",
    "    )\n",
    "\n",
    "    f1 = f1_score(np.ravel(true_adj), np.ravel(est_adj))\n",
    "    print(\"F1-score for the model skeleton: \", f1)"
   ]
  },
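  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of `get_f1_score`, the toy example below compares small three-node graphs. Plain `networkx.DiGraph` objects stand in for pgmpy models here, an assumption that works because the function only calls `nodes()` and `to_undirected()`. The graphs themselves are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy graphs (hypothetical): true A -> B -> C, estimated A -> B and A -> C.\n",
    "toy_true = nx.DiGraph([(\"A\", \"B\"), (\"B\", \"C\")])\n",
    "toy_est = nx.DiGraph([(\"A\", \"B\"), (\"A\", \"C\")])\n",
    "\n",
    "# The skeletons share A-B; the estimate misses B-C and adds A-C,\n",
    "# so precision and recall are both 0.5.\n",
    "get_f1_score(toy_est, toy_true)\n",
    "\n",
    "# Reversing an edge leaves the skeleton unchanged, so this metric\n",
    "# cannot distinguish structures that differ only in edge orientation.\n",
    "get_f1_score(nx.DiGraph([(\"B\", \"A\"), (\"B\", \"C\")]), toy_true)"
   ]
  },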
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. PC algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'INTUBATION': 'C', 'ANAPHYLAXIS': 'C', 'VENTLUNG': 'C', 'HISTORY': 'C', 'LVEDVOLUME': 'C', 'VENTMACH': 'C', 'TPR': 'C', 'MINVOL': 'C', 'VENTTUBE': 'C', 'HRBP': 'C', 'BP': 'C', 'PAP': 'C', 'PULMEMBOLUS': 'C', 'DISCONNECT': 'C', 'ERRLOWOUTPUT': 'C', 'HYPOVOLEMIA': 'C', 'CVP': 'C', 'HREKG': 'C', 'CO': 'C', 'PRESS': 'C', 'HRSAT': 'C', 'CATECHOL': 'C', 'EXPCO2': 'C', 'PVSAT': 'C', 'ARTCO2': 'C', 'LVFAILURE': 'C', 'KINKEDTUBE': 'C', 'VENTALV': 'C', 'FIO2': 'C', 'SHUNT': 'C', 'PCWP': 'C', 'SAO2': 'C', 'HR': 'C', 'MINVOLSET': 'C', 'STROKEVOLUME': 'C', 'INSUFFANESTH': 'C', 'ERRCAUTER': 'C'}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c277a6a8093a4592a2a3bbf953cb5f27",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pgmpy:Reached maximum number of allowed conditional variables. Exiting\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1-score for the model skeleton:  0.825\n"
     ]
    }
   ],
   "source": [
    "# Learning the discrete variable alarm model back\n",
    "\n",
    "est = PC(data=alarm_samples)\n",
    "estimated_model = est.estimate(ci_test='chi_square', variant=\"stable\", max_cond_vars=4, return_type='dag')\n",
    "get_f1_score(estimated_model, alarm_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'b1191': 'N', 'cspG': 'N', 'eutG': 'N', 'fixC': 'N', 'cspA': 'N', 'yecO': 'N', 'yedE': 'N', 'sucA': 'N', 'cchB': 'N', 'yceP': 'N', 'ygbD': 'N', 'yjbO': 'N', 'yfiA': 'N', 'lpdA': 'N', 'pspB': 'N', 'atpG': 'N', 'dnaJ': 'N', 'flgD': 'N', 'gltA': 'N', 'sucD': 'N', 'tnaA': 'N', 'ygcE': 'N', 'yhdM': 'N', 'ibpB': 'N', 'yfaD': 'N', 'hupB': 'N', 'pspA': 'N', 'asnA': 'N', 'atpD': 'N', 'nmpC': 'N', 'icdA': 'N', 'lacA': 'N', 'yheI': 'N', 'aceB': 'N', 'lacY': 'N', 'b1963': 'N', 'dnaK': 'N', 'folK': 'N', 'ycgX': 'N', 'lacZ': 'N', 'nuoM': 'N', 'dnaG': 'N', 'b1583': 'N', 'mopB': 'N', 'yaeM': 'N', 'ftsJ': 'N'}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3b6e02f74ff48fcaa9b4d2efd710852",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pgmpy:Reached maximum number of allowed conditional variables. Exiting\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1-score for the model skeleton:  0.640625\n"
     ]
    }
   ],
   "source": [
    "# Learning the continuous variable ecoli model back\n",
    "\n",
    "est = PC(data=ecoli_samples)\n",
    "estimated_model = est.estimate(ci_test='pearsonr', variant=\"orig\", max_cond_vars=4, return_type='dag')\n",
    "get_f1_score(estimated_model, ecoli_model)"
   ]
  },
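  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The chi-square and G-squared tests listed at the top of this notebook compute different statistics from the same contingency table and compare them against the same chi-square distribution. The sketch below shows this for the unconditional case with `scipy.stats.chi2_contingency` on a made-up 2x2 table. It illustrates the statistics themselves and should not be read as pgmpy's internal implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import chi2_contingency\n",
    "\n",
    "# Hypothetical 2x2 table of counts for two binary variables X (rows) and Y (columns).\n",
    "table = np.array([[30, 10], [15, 45]])\n",
    "\n",
    "# Pearson chi-square statistic (the default). The continuity correction is\n",
    "# disabled so that both statistics are computed from the raw counts.\n",
    "chi2_stat, chi2_p, dof, _ = chi2_contingency(table, correction=False)\n",
    "\n",
    "# G-squared (log-likelihood ratio) statistic on the same table.\n",
    "g_stat, g_p, _, _ = chi2_contingency(table, correction=False, lambda_=\"log-likelihood\")\n",
    "\n",
    "print(f\"chi-square: statistic={chi2_stat:.3f}, p-value={chi2_p:.2e}, dof={dof}\")\n",
    "print(f\"G-squared:  statistic={g_stat:.3f}, p-value={g_p:.2e}\")"
   ]
  },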
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Hill-Climb Search Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'INTUBATION': 'C', 'ANAPHYLAXIS': 'C', 'VENTLUNG': 'C', 'HISTORY': 'C', 'LVEDVOLUME': 'C', 'VENTMACH': 'C', 'TPR': 'C', 'MINVOL': 'C', 'VENTTUBE': 'C', 'HRBP': 'C', 'BP': 'C', 'PAP': 'C', 'PULMEMBOLUS': 'C', 'DISCONNECT': 'C', 'ERRLOWOUTPUT': 'C', 'HYPOVOLEMIA': 'C', 'CVP': 'C', 'HREKG': 'C', 'CO': 'C', 'PRESS': 'C', 'HRSAT': 'C', 'CATECHOL': 'C', 'EXPCO2': 'C', 'PVSAT': 'C', 'ARTCO2': 'C', 'LVFAILURE': 'C', 'KINKEDTUBE': 'C', 'VENTALV': 'C', 'FIO2': 'C', 'SHUNT': 'C', 'PCWP': 'C', 'SAO2': 'C', 'HR': 'C', 'MINVOLSET': 'C', 'STROKEVOLUME': 'C', 'INSUFFANESTH': 'C', 'ERRCAUTER': 'C'}\n",
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'INTUBATION': 'C', 'ANAPHYLAXIS': 'C', 'VENTLUNG': 'C', 'HISTORY': 'C', 'LVEDVOLUME': 'C', 'VENTMACH': 'C', 'TPR': 'C', 'MINVOL': 'C', 'VENTTUBE': 'C', 'HRBP': 'C', 'BP': 'C', 'PAP': 'C', 'PULMEMBOLUS': 'C', 'DISCONNECT': 'C', 'ERRLOWOUTPUT': 'C', 'HYPOVOLEMIA': 'C', 'CVP': 'C', 'HREKG': 'C', 'CO': 'C', 'PRESS': 'C', 'HRSAT': 'C', 'CATECHOL': 'C', 'EXPCO2': 'C', 'PVSAT': 'C', 'ARTCO2': 'C', 'LVFAILURE': 'C', 'KINKEDTUBE': 'C', 'VENTALV': 'C', 'FIO2': 'C', 'SHUNT': 'C', 'PCWP': 'C', 'SAO2': 'C', 'HR': 'C', 'MINVOLSET': 'C', 'STROKEVOLUME': 'C', 'INSUFFANESTH': 'C', 'ERRCAUTER': 'C'}\n",
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'INTUBATION': 'C', 'ANAPHYLAXIS': 'C', 'VENTLUNG': 'C', 'HISTORY': 'C', 'LVEDVOLUME': 'C', 'VENTMACH': 'C', 'TPR': 'C', 'MINVOL': 'C', 'VENTTUBE': 'C', 'HRBP': 'C', 'BP': 'C', 'PAP': 'C', 'PULMEMBOLUS': 'C', 'DISCONNECT': 'C', 'ERRLOWOUTPUT': 'C', 'HYPOVOLEMIA': 'C', 'CVP': 'C', 'HREKG': 'C', 'CO': 'C', 'PRESS': 'C', 'HRSAT': 'C', 'CATECHOL': 'C', 'EXPCO2': 'C', 'PVSAT': 'C', 'ARTCO2': 'C', 'LVFAILURE': 'C', 'KINKEDTUBE': 'C', 'VENTALV': 'C', 'FIO2': 'C', 'SHUNT': 'C', 'PCWP': 'C', 'SAO2': 'C', 'HR': 'C', 'MINVOLSET': 'C', 'STROKEVOLUME': 'C', 'INSUFFANESTH': 'C', 'ERRCAUTER': 'C'}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fb1cfeeaedad4c6fad37acd97fa383b0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1-score for the model skeleton:  0.7719298245614035\n"
     ]
    }
   ],
   "source": [
    "# Learning the discrete variable alarm model back\n",
    "\n",
    "est = HillClimbSearch(data=alarm_samples)\n",
    "estimated_model = est.estimate(scoring_method=\"k2\", max_indegree=4, max_iter=int(1e4))\n",
    "get_f1_score(estimated_model, alarm_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'b1191': 'N', 'cspG': 'N', 'eutG': 'N', 'fixC': 'N', 'cspA': 'N', 'yecO': 'N', 'yedE': 'N', 'sucA': 'N', 'cchB': 'N', 'yceP': 'N', 'ygbD': 'N', 'yjbO': 'N', 'yfiA': 'N', 'lpdA': 'N', 'pspB': 'N', 'atpG': 'N', 'dnaJ': 'N', 'flgD': 'N', 'gltA': 'N', 'sucD': 'N', 'tnaA': 'N', 'ygcE': 'N', 'yhdM': 'N', 'ibpB': 'N', 'yfaD': 'N', 'hupB': 'N', 'pspA': 'N', 'asnA': 'N', 'atpD': 'N', 'nmpC': 'N', 'icdA': 'N', 'lacA': 'N', 'yheI': 'N', 'aceB': 'N', 'lacY': 'N', 'b1963': 'N', 'dnaK': 'N', 'folK': 'N', 'ycgX': 'N', 'lacZ': 'N', 'nuoM': 'N', 'dnaG': 'N', 'b1583': 'N', 'mopB': 'N', 'yaeM': 'N', 'ftsJ': 'N'}\n",
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'b1191': 'N', 'cspG': 'N', 'eutG': 'N', 'fixC': 'N', 'cspA': 'N', 'yecO': 'N', 'yedE': 'N', 'sucA': 'N', 'cchB': 'N', 'yceP': 'N', 'ygbD': 'N', 'yjbO': 'N', 'yfiA': 'N', 'lpdA': 'N', 'pspB': 'N', 'atpG': 'N', 'dnaJ': 'N', 'flgD': 'N', 'gltA': 'N', 'sucD': 'N', 'tnaA': 'N', 'ygcE': 'N', 'yhdM': 'N', 'ibpB': 'N', 'yfaD': 'N', 'hupB': 'N', 'pspA': 'N', 'asnA': 'N', 'atpD': 'N', 'nmpC': 'N', 'icdA': 'N', 'lacA': 'N', 'yheI': 'N', 'aceB': 'N', 'lacY': 'N', 'b1963': 'N', 'dnaK': 'N', 'folK': 'N', 'ycgX': 'N', 'lacZ': 'N', 'nuoM': 'N', 'dnaG': 'N', 'b1583': 'N', 'mopB': 'N', 'yaeM': 'N', 'ftsJ': 'N'}\n",
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'b1191': 'N', 'cspG': 'N', 'eutG': 'N', 'fixC': 'N', 'cspA': 'N', 'yecO': 'N', 'yedE': 'N', 'sucA': 'N', 'cchB': 'N', 'yceP': 'N', 'ygbD': 'N', 'yjbO': 'N', 'yfiA': 'N', 'lpdA': 'N', 'pspB': 'N', 'atpG': 'N', 'dnaJ': 'N', 'flgD': 'N', 'gltA': 'N', 'sucD': 'N', 'tnaA': 'N', 'ygcE': 'N', 'yhdM': 'N', 'ibpB': 'N', 'yfaD': 'N', 'hupB': 'N', 'pspA': 'N', 'asnA': 'N', 'atpD': 'N', 'nmpC': 'N', 'icdA': 'N', 'lacA': 'N', 'yheI': 'N', 'aceB': 'N', 'lacY': 'N', 'b1963': 'N', 'dnaK': 'N', 'folK': 'N', 'ycgX': 'N', 'lacZ': 'N', 'nuoM': 'N', 'dnaG': 'N', 'b1583': 'N', 'mopB': 'N', 'yaeM': 'N', 'ftsJ': 'N'}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5649aa3962c94e7aa2151c1e022d1545",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1-score for the model skeleton:  0.7692307692307693\n"
     ]
    }
   ],
   "source": [
    "# Learning the continuous variable ecoli model back\n",
    "\n",
    "est = HillClimbSearch(data=ecoli_samples)\n",
    "estimated_model = est.estimate(scoring_method=\"bic-g\", max_indegree=4, max_iter=int(1e4))\n",
    "get_f1_score(estimated_model, ecoli_model)"
   ]
  },
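  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scoring methods listed at the top of this notebook can be swapped by changing only the `scoring_method` string. The loop below is a small sketch that reuses `alarm_samples` and the same `estimate` arguments as the Hill-Climb cells above to compare three of the discrete scores. The resulting F1 values depend on the random sample, so none are quoted here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare a few discrete scoring methods on the same simulated alarm data.\n",
    "for method in [\"k2\", \"bdeu\", \"bic-d\"]:\n",
    "    est = HillClimbSearch(data=alarm_samples)\n",
    "    estimated_model = est.estimate(scoring_method=method, max_indegree=4, max_iter=int(1e4))\n",
    "    print(method, end=\" -> \")\n",
    "    get_f1_score(estimated_model, alarm_model)"
   ]
  },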
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. GES algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'INTUBATION': 'C', 'ANAPHYLAXIS': 'C', 'VENTLUNG': 'C', 'HISTORY': 'C', 'LVEDVOLUME': 'C', 'VENTMACH': 'C', 'TPR': 'C', 'MINVOL': 'C', 'VENTTUBE': 'C', 'HRBP': 'C', 'BP': 'C', 'PAP': 'C', 'PULMEMBOLUS': 'C', 'DISCONNECT': 'C', 'ERRLOWOUTPUT': 'C', 'HYPOVOLEMIA': 'C', 'CVP': 'C', 'HREKG': 'C', 'CO': 'C', 'PRESS': 'C', 'HRSAT': 'C', 'CATECHOL': 'C', 'EXPCO2': 'C', 'PVSAT': 'C', 'ARTCO2': 'C', 'LVFAILURE': 'C', 'KINKEDTUBE': 'C', 'VENTALV': 'C', 'FIO2': 'C', 'SHUNT': 'C', 'PCWP': 'C', 'SAO2': 'C', 'HR': 'C', 'MINVOLSET': 'C', 'STROKEVOLUME': 'C', 'INSUFFANESTH': 'C', 'ERRCAUTER': 'C'}\n",
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'INTUBATION': 'C', 'ANAPHYLAXIS': 'C', 'VENTLUNG': 'C', 'HISTORY': 'C', 'LVEDVOLUME': 'C', 'VENTMACH': 'C', 'TPR': 'C', 'MINVOL': 'C', 'VENTTUBE': 'C', 'HRBP': 'C', 'BP': 'C', 'PAP': 'C', 'PULMEMBOLUS': 'C', 'DISCONNECT': 'C', 'ERRLOWOUTPUT': 'C', 'HYPOVOLEMIA': 'C', 'CVP': 'C', 'HREKG': 'C', 'CO': 'C', 'PRESS': 'C', 'HRSAT': 'C', 'CATECHOL': 'C', 'EXPCO2': 'C', 'PVSAT': 'C', 'ARTCO2': 'C', 'LVFAILURE': 'C', 'KINKEDTUBE': 'C', 'VENTALV': 'C', 'FIO2': 'C', 'SHUNT': 'C', 'PCWP': 'C', 'SAO2': 'C', 'HR': 'C', 'MINVOLSET': 'C', 'STROKEVOLUME': 'C', 'INSUFFANESTH': 'C', 'ERRCAUTER': 'C'}\n",
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'INTUBATION': 'C', 'ANAPHYLAXIS': 'C', 'VENTLUNG': 'C', 'HISTORY': 'C', 'LVEDVOLUME': 'C', 'VENTMACH': 'C', 'TPR': 'C', 'MINVOL': 'C', 'VENTTUBE': 'C', 'HRBP': 'C', 'BP': 'C', 'PAP': 'C', 'PULMEMBOLUS': 'C', 'DISCONNECT': 'C', 'ERRLOWOUTPUT': 'C', 'HYPOVOLEMIA': 'C', 'CVP': 'C', 'HREKG': 'C', 'CO': 'C', 'PRESS': 'C', 'HRSAT': 'C', 'CATECHOL': 'C', 'EXPCO2': 'C', 'PVSAT': 'C', 'ARTCO2': 'C', 'LVFAILURE': 'C', 'KINKEDTUBE': 'C', 'VENTALV': 'C', 'FIO2': 'C', 'SHUNT': 'C', 'PCWP': 'C', 'SAO2': 'C', 'HR': 'C', 'MINVOLSET': 'C', 'STROKEVOLUME': 'C', 'INSUFFANESTH': 'C', 'ERRCAUTER': 'C'}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1-score for the model skeleton:  0.8444444444444444\n"
     ]
    }
   ],
   "source": [
    "# Learning the discrete variable alarm model back\n",
    "\n",
    "est = GES(data=alarm_samples)\n",
    "estimated_model = est.estimate(scoring_method=\"bic-d\")\n",
    "get_f1_score(estimated_model, alarm_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'b1191': 'N', 'cspG': 'N', 'eutG': 'N', 'fixC': 'N', 'cspA': 'N', 'yecO': 'N', 'yedE': 'N', 'sucA': 'N', 'cchB': 'N', 'yceP': 'N', 'ygbD': 'N', 'yjbO': 'N', 'yfiA': 'N', 'lpdA': 'N', 'pspB': 'N', 'atpG': 'N', 'dnaJ': 'N', 'flgD': 'N', 'gltA': 'N', 'sucD': 'N', 'tnaA': 'N', 'ygcE': 'N', 'yhdM': 'N', 'ibpB': 'N', 'yfaD': 'N', 'hupB': 'N', 'pspA': 'N', 'asnA': 'N', 'atpD': 'N', 'nmpC': 'N', 'icdA': 'N', 'lacA': 'N', 'yheI': 'N', 'aceB': 'N', 'lacY': 'N', 'b1963': 'N', 'dnaK': 'N', 'folK': 'N', 'ycgX': 'N', 'lacZ': 'N', 'nuoM': 'N', 'dnaG': 'N', 'b1583': 'N', 'mopB': 'N', 'yaeM': 'N', 'ftsJ': 'N'}\n",
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'b1191': 'N', 'cspG': 'N', 'eutG': 'N', 'fixC': 'N', 'cspA': 'N', 'yecO': 'N', 'yedE': 'N', 'sucA': 'N', 'cchB': 'N', 'yceP': 'N', 'ygbD': 'N', 'yjbO': 'N', 'yfiA': 'N', 'lpdA': 'N', 'pspB': 'N', 'atpG': 'N', 'dnaJ': 'N', 'flgD': 'N', 'gltA': 'N', 'sucD': 'N', 'tnaA': 'N', 'ygcE': 'N', 'yhdM': 'N', 'ibpB': 'N', 'yfaD': 'N', 'hupB': 'N', 'pspA': 'N', 'asnA': 'N', 'atpD': 'N', 'nmpC': 'N', 'icdA': 'N', 'lacA': 'N', 'yheI': 'N', 'aceB': 'N', 'lacY': 'N', 'b1963': 'N', 'dnaK': 'N', 'folK': 'N', 'ycgX': 'N', 'lacZ': 'N', 'nuoM': 'N', 'dnaG': 'N', 'b1583': 'N', 'mopB': 'N', 'yaeM': 'N', 'ftsJ': 'N'}\n",
      "INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: \n",
      " {'b1191': 'N', 'cspG': 'N', 'eutG': 'N', 'fixC': 'N', 'cspA': 'N', 'yecO': 'N', 'yedE': 'N', 'sucA': 'N', 'cchB': 'N', 'yceP': 'N', 'ygbD': 'N', 'yjbO': 'N', 'yfiA': 'N', 'lpdA': 'N', 'pspB': 'N', 'atpG': 'N', 'dnaJ': 'N', 'flgD': 'N', 'gltA': 'N', 'sucD': 'N', 'tnaA': 'N', 'ygcE': 'N', 'yhdM': 'N', 'ibpB': 'N', 'yfaD': 'N', 'hupB': 'N', 'pspA': 'N', 'asnA': 'N', 'atpD': 'N', 'nmpC': 'N', 'icdA': 'N', 'lacA': 'N', 'yheI': 'N', 'aceB': 'N', 'lacY': 'N', 'b1963': 'N', 'dnaK': 'N', 'folK': 'N', 'ycgX': 'N', 'lacZ': 'N', 'nuoM': 'N', 'dnaG': 'N', 'b1583': 'N', 'mopB': 'N', 'yaeM': 'N', 'ftsJ': 'N'}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1-score for the model skeleton:  0.8461538461538461\n"
     ]
    }
   ],
   "source": [
    "# Learning the continuous variable ecoli model back\n",
    "\n",
    "est = GES(data=ecoli_samples)\n",
    "estimated_model = est.estimate(scoring_method=\"bic-g\")\n",
    "get_f1_score(estimated_model, ecoli_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Expert In Loop Algorithm\n",
    "\n",
    "Please refer to the following blogpost for more details: https://medium.com/gopenai/llms-for-causal-discovery-745e2cba0b59"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
